import torch
from w2v_costh import forward


# Smoke check for `w2v_costh.forward`: call it once with random inputs placed
# on CUDA device 0, then report the number of returned values and the
# shape/device of the first three.
# NOTE(review): needs a CUDA device at index 0; the meaning of the two
# operands (23x768 and 768x1211) is defined by `forward` -- confirm there.

# Sample on the CPU first, then move to the GPU (same order as the draws).
x_norm_fwd = torch.randn(23, 768)
x_norm_fwd = x_norm_fwd.cuda(0)
w_fwd = torch.randn(768, 1211)
w_fwd = w_fwd.cuda(0)

fwd_res = forward(x_norm_fwd, w_fwd)

print("Length of forward return value is {}".format(len(fwd_res)))

# Exactly three entries are inspected; a shorter result raises IndexError.
for idx in range(3):
    item = fwd_res[idx]
    print(f"The shape of forward return value[{idx}] is: ", item.shape, item.device)

